# This script illustrates how the number of local updates shifts the fixed point of Local SGD, emphasizing the effect of a single machine with sharp curvature (an outlier). Used to generate Figure 3.

import numpy as np
import matplotlib.pyplot as plt

# Seed the global NumPy RNG so the random optima are identical on every run
np.random.seed(42)

# Experiment parameters
M = 10                                # how many machines take part
eta = 0.01                            # step-size of each local update
K_vals = [1, 2, 5, 10, 20, 50, 100]   # local step counts to compare

# Each machine m owns a quadratic described by a Hessian A_m and an optimum x_m^*.
# The first M - 1 machines are well-conditioned: identity Hessian and an optimum
# drawn as 0.5 * standard normal, i.e. scattered around the origin.
A_matrices = [np.array([[1, 0], [0, 1]]) for _ in range(M - 1)]
x_m_stars = [np.random.randn(2) * 0.5 for _ in range(M - 1)]

# The last machine is the outlier: curvature 100 along the first coordinate
# and an optimum placed farther away, at (3, 0).
A_matrices.append(np.array([[100, 0], [0, 1]]))
x_m_stars.append(np.array([3.0, 0.0]))

# Synchronized SGD solution: solve (sum_m A_m) x = sum_m A_m x_m^*
A_sync = sum(A_matrices)
b_sync = sum(hessian @ optimum for hessian, optimum in zip(A_matrices, x_m_stars))
x_sync = np.linalg.solve(A_sync, b_sync)


def _local_sgd_fixed_point(num_local_steps):
    """Solve sum_m C_m x = sum_m C_m x_m^* with C_m = I - (I - eta * A_m)^K."""
    identity = np.eye(2)
    lhs = 0
    rhs = 0
    for hessian, optimum in zip(A_matrices, x_m_stars):
        weight = identity - np.linalg.matrix_power(identity - eta * hessian, num_local_steps)
        lhs = lhs + weight
        rhs = rhs + weight @ optimum
    return np.linalg.solve(lhs, rhs)


# Local SGD solutions, one per entry of K_vals (same order)
x_local_solutions = [_local_sgd_fixed_point(K) for K in K_vals]

# Function to plot contour ellipses
def plot_quadratic_contour(A, x_star, ax, color='gray', levels=(0.1, 0.2, 0.3), linestyle='--', label=None, alpha=0.3):
    """Draw elliptical level sets of a 2-D quadratic centred at ``x_star``.

    For every ``level`` in ``levels`` the closed curve
    ``(x - x_star)^T A (x - x_star) = level`` is sampled at 200 angles and
    drawn on ``ax`` with ``ax.plot``.

    Args:
        A: Symmetric positive-definite 2x2 matrix (any array-like).
        x_star: Centre of the ellipses, a length-2 array-like.
        ax: Object exposing a Matplotlib-style ``plot`` method.
        color: Line colour forwarded to ``ax.plot``.
        levels: Iterable of level values, one ellipse per value.
        linestyle: Line style forwarded to ``ax.plot``.
        label: Legend label; attached to the first ellipse only.
        alpha: Line transparency forwarded to ``ax.plot``.

    Raises:
        ValueError: If ``A`` is not 2x2 or has a non-positive eigenvalue,
            in which case the level sets are not ellipses.
    """
    A = np.asarray(A)
    center = np.asarray(x_star, dtype=float)
    if A.shape != (2, 2):
        raise ValueError(f"A must be a 2x2 matrix, got shape {A.shape}")

    eigvals, eigvecs = np.linalg.eigh(A)
    if np.any(eigvals <= 0):
        raise ValueError("A must be positive definite to draw contour ellipses")

    theta = np.linspace(0, 2*np.pi, 200)
    unit_circle = np.array([np.cos(theta), np.sin(theta)])
    # Semi-axis length per eigen-direction for a unit level: 1 / sqrt(eigenvalue)
    inv_sqrt_eigvals = np.sqrt(1 / eigvals)

    for level in levels:
        radius = np.sqrt(level)
        # Stretch the unit circle in the eigenbasis, rotate back, then shift
        ellipse = eigvecs @ (radius * inv_sqrt_eigvals[:, None] * unit_circle)
        ellipse = ellipse + center[:, None]
        ax.plot(ellipse[0], ellipse[1], color=color, lw=1, linestyle=linestyle, alpha=alpha, label=label)
        label = None  # label only once so the legend gets a single entry

# Plotting: one figure with thin reference lines through the origin
fig, ax = plt.subplots(figsize=(8, 6))
for draw_reference_line in (ax.axhline, ax.axvline):
    draw_reference_line(0, color='gray', lw=0.5)

# Colour palette
well_color = '#1f77b4'    # blue  - well-conditioned machines
outlier_color = '#d62728' # red   - outlier machine
sync_color = '#2ca02c'    # green - synchronized SGD solution

# Well-conditioned machines: optimum marker plus faint contours.
# Only the first marker carries a legend label so the legend has one entry.
for idx, (hessian, optimum) in enumerate(zip(A_matrices[:M - 1], x_m_stars[:M - 1])):
    legend_label = None if idx else 'well-conditioned'
    ax.scatter(*optimum, color=well_color, s=70, marker='^', alpha=0.9, label=legend_label)
    plot_quadratic_contour(hessian, optimum, ax, color=well_color, linestyle='-', levels=[0.1, 0.2, 0.3], alpha=0.1)

# Outlier machine (last entry): larger marker and slightly stronger contours
outlier_hessian = A_matrices[-1]
outlier_optimum = x_m_stars[-1]
ax.scatter(*outlier_optimum, color=outlier_color, label='outlier', s=120, marker='*')
plot_quadratic_contour(outlier_hessian, outlier_optimum, ax, color=outlier_color, linestyle='-', levels=[0.1, 0.2, 0.3], alpha=0.2)

# Synchronized SGD solution
ax.scatter(*x_sync, color=sync_color, label=r'$x_\infty^{SGD}$', marker='X', s=100)

# Local SGD solutions: one marker per K, coloured along the viridis colormap
colors = plt.cm.viridis(np.linspace(0.15, 0.85, len(K_vals)))
for k, solution, shade in zip(K_vals, x_local_solutions, colors):
    ax.scatter(*solution, color=shade, label=fr'$K={k}$', marker='o', s=70)

# Final formatting: titles, legend, equal aspect so ellipses are not distorted
ax.set_title("Curvature and Fixed Points: Local SGD vs Synchronized SGD")
ax.set_xlabel(r"$x_1$")
ax.set_ylabel(r"$x_2$")
ax.legend(loc='best', fontsize=9)
ax.axis('equal')
ax.grid(True)
plt.tight_layout()
plt.show()
